# Copyright 2019 DeepMind Technologies Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Implementation of exploitability descent.

See "Computing Approximate Equilibria in Sequential Adversarial Games by
Exploitability Descent" https://arxiv.org/abs/1903.05614

The exploitability descent algorithm solves a game by repeatedly performing
the following update:

1. Construct a (deterministic) best response to our current strategy
2. Compute the value of every action in every state when playing our current
   strategy vs the best response.
3. Update our current strategy to do better vs the current best response
   by performing a policy-gradient update.

This module provides a function that returns a loss for network training, and
a Solver class that uses this loss in a tabular Exploitability Descent.

The code can be used either for a tabular exploitability descent algorithm,
as demonstrated by exploitability_descent_test, or for a neural network policy,
as in ../examples/exploitability_descent.py.

Additionally, for a minibatch version of the algorithm (which samples
uniformly across all states in the game to generate a minibatch), see the
minibatch_loss method.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import tensorflow.compat.v1 as tf

from open_spiel.python import policy
from open_spiel.python.algorithms import action_value_vs_best_response
from open_spiel.python.algorithms import masked_softmax

# Temporary disabling of v2 behavior until code is updated.
tf.disable_v2_behavior()

_NUM_PLAYERS = 2


def _create_policy_evaluator(tabular_policy, q_value_calculator):
  """Returns a function evaluating a tabular policy against best responses.

  Args:
    tabular_policy: A `policy.TabularPolicy` whose `action_probability_array`
      is overwritten in-place with the policy values being evaluated.
    q_value_calculator: An `action_value_vs_best_response.Calculator` for the
      game; invoked once per player.

  Returns:
    A function mapping a policy array to a
    `(nash_conv, q_values, cf_reach_probabilities)` tuple of numpy values,
    suitable for wrapping in `tf.py_func`.
  """

  def evaluate_policy(policy_values):
    """Evaluates a tabular policy; intended to be used as a tf.py_function."""
    # Install the candidate policy into the shared tabular policy in-place.
    tabular_policy.action_probability_array = policy_values
    evaluations = [
        q_value_calculator(player, tabular_policy,
                           tabular_policy.states_per_player[player])
        for player in range(_NUM_PLAYERS)
    ]
    # NashConv is the sum of the per-player exploitabilities.
    nash_conv = np.float64(
        sum(evaluations[p].exploitability for p in range(_NUM_PLAYERS)))

    # Per-state action values and counterfactual reach probabilities when
    # playing the current policy against the best responder, concatenated
    # across players.
    q_values = np.concatenate([
        np.array(evaluations[p].values_vs_br, np.float64)
        for p in range(_NUM_PLAYERS)
    ])
    cf_reach_probabilities = np.concatenate([
        np.array(evaluations[p].counterfactual_reach_probs_vs_br, np.float64)
        for p in range(_NUM_PLAYERS)
    ])
    return nash_conv, q_values, cf_reach_probabilities

  return evaluate_policy


class LossCalculator(object):
  """Computes the exploitability descent loss for a two-player game."""

  def __init__(self, game):
    """Initializes the loss calculation machinery for `game`.

    Args:
      game: An open_spiel game instance with exactly two players.

    Raises:
      ValueError: If `game` does not have exactly two players.
    """
    if game.num_players() != _NUM_PLAYERS:
      raise ValueError("Game {} does not have {} players.".format(
          game, _NUM_PLAYERS))
    self.tabular_policy = policy.TabularPolicy(game)
    self.q_value_calculator = action_value_vs_best_response.Calculator(game)

  def masked_softmax(self, logits):
    """Safe masked softmax.

    Args:
      logits: Unnormalized policy logits, one row per state.

    Returns:
      A policy tensor assigning zero probability to illegal actions.
    """
    legal_mask = self.tabular_policy.legal_actions_mask
    return masked_softmax.tf_masked_softmax(logits, legal_mask)

  def loss(self, policy_values):
    """Returns the exploitability descent loss given a policy.

    Args:
      policy_values: The tabular policy to evaluate, as a tensor.

    Returns:
      A `(nash_conv, loss)` pair of scalar tensors.
    """
    evaluator = _create_policy_evaluator(self.tabular_policy,
                                         self.q_value_calculator)
    # The best-response evaluation runs outside the graph, in numpy.
    nash_conv, q_values, cf_reach_probabilities = tf.py_func(
        evaluator, [policy_values], [tf.float64, tf.float64, tf.float64])
    # Per-state value under the current policy. It is treated as a constant
    # baseline so that gradients flow only through the policy term.
    state_values = tf.reduce_sum(
        policy_values * q_values, axis=-1, keepdims=True)
    advantages = q_values - tf.stop_gradient(state_values)
    per_state_loss = -tf.reduce_sum(policy_values * advantages, axis=-1)
    total_loss = tf.reduce_sum(per_state_loss * cf_reach_probabilities)
    return nash_conv, total_loss

  def minibatch_loss(self, policy_values, q_values, indices):
    """Returns the exploitability descent loss given a policy for a subset.

    Args:
      policy_values: The tabular policy to evaluate, as a tensor.
      q_values: Predicted q-values for every state (e.g. a network output).
      indices: Indices of the states forming the minibatch.

    Returns:
      A `(nash_conv, q_value_loss, policy_loss)` tuple of scalar tensors.
    """
    evaluator = _create_policy_evaluator(self.tabular_policy,
                                         self.q_value_calculator)
    nash_conv, true_q_values, cf_reach_probabilities = tf.py_func(
        evaluator, [policy_values], [tf.float64, tf.float64, tf.float64])
    # Advantages are built from the *predicted* q-values over all states.
    state_values = tf.reduce_sum(
        policy_values * q_values, axis=-1, keepdims=True)
    advantages = q_values - state_values

    # Restrict the loss computation to the sampled minibatch of states.
    batch_policy = tf.gather(policy_values, indices)
    batch_advantages = tf.gather(advantages, indices)
    batch_reach = tf.gather(cf_reach_probabilities, indices)

    # Policy-gradient term; the advantages are treated as constants here.
    per_state_loss = -tf.reduce_sum(
        batch_policy * tf.stop_gradient(batch_advantages), axis=-1)
    policy_loss = tf.reduce_sum(per_state_loss * batch_reach)

    # Regression term pulling the predicted q-values towards the true ones.
    squared_error = tf.reduce_mean((q_values - true_q_values)**2, axis=1)
    q_value_loss = tf.reduce_sum(
        tf.gather(squared_error, indices) * batch_reach)
    return nash_conv, q_value_loss, policy_loss


class Solver(object):
  """Solves a two-player game using exploitability descent."""

  def __init__(self, game):
    """Builds the TF graph for tabular exploitability descent on `game`."""
    self._loss_calculator = LossCalculator(game)
    # One logit per (state, action) entry; start uniform over legal actions.
    initial_logits = np.ones_like(
        self._loss_calculator.tabular_policy.action_probability_array,
        dtype=np.float64)
    self._logits = tf.Variable(
        initial_logits, name="logits", use_resource=True)
    self._tabular_policy = self._loss_calculator.masked_softmax(self._logits)
    self._nash_conv, self._loss = self._loss_calculator.loss(
        self._tabular_policy)
    self._learning_rate = tf.placeholder(tf.float64, (), name="learning_rate")
    self._optimizer = tf.train.GradientDescentOptimizer(self._learning_rate)
    self._optimizer_step = self._optimizer.minimize(self._loss)

  def step(self, session, learning_rate):
    """Takes a single exploitability descent step.

    Args:
      session: A TF session in which to run the update op.
      learning_rate: Step size for the gradient descent update.

    Returns:
      The NashConv of the policy before the update was applied.
    """
    fetches = [self._optimizer_step, self._nash_conv]
    feed = {self._learning_rate: learning_rate}
    _, nash_conv = session.run(fetches, feed_dict=feed)
    return nash_conv
